{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "eb9a4b5a",
   "metadata": {},
   "source": [
    "# Simple Plotting\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88c7ff9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "RESULTS_PATH = \"../../your_sweep_results_path\"\n",
    "\n",
    "PLOT_ALL_SEEDS = False\n",
    "# Full sweep\n",
    "MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\", \"gpt2-xl\", \"Qwen/Qwen-1_8B\", \"Qwen/Qwen-7B\", \"Qwen/Qwen-14B\"]\n",
    "# Minimal sweep\n",
    "# MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00ca073c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_style('whitegrid')\n",
    "\n",
    "from IPython.display import display\n",
    "\n",
    "import os\n",
    "import glob\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5caa051",
   "metadata": {},
   "outputs": [],
   "source": [
    "records = []\n",
    "all_results_folders = ['/'.join(e.split('/')[:-1]) for e in glob.glob(os.path.join(RESULTS_PATH, \"**/*.results_summary.json\"), recursive=True)]\n",
    "for result_folder in set(all_results_folders):\n",
    "    config_file = os.path.join(result_folder, \"config.json\")\n",
    "    config = json.load(open(config_file, \"r\"))\n",
    "    if config[\"strong_model_size\"] not in MODELS_TO_PLOT:\n",
    "        continue\n",
    "    if 'seed' not in config:\n",
    "        config['seed'] = 0\n",
    "    result_filename = (config[\"weak_model_size\"].replace('.', '_') + \"_\" + config[\"strong_model_size\"].replace('.', '_') + \".results_summary.json\").replace('/', '_')\n",
    "    record = config.copy()\n",
    "    record.update(json.load(open(config_file.replace('config.json', result_filename))))\n",
    "    records.append(record)\n",
    "\n",
    "df = pd.DataFrame.from_records(records).sort_values(['ds_name', 'weak_model_size', 'strong_model_size'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f628577",
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = df.ds_name.unique()\n",
    "for dataset in datasets:\n",
    "    cur_df = df[(df.ds_name == dataset)]\n",
    "    base_df = pd.concat([\n",
    "        pd.DataFrame.from_dict({\"strong_model_size\": cur_df['weak_model_size'].to_list(), \"accuracy\": cur_df['weak_acc'].to_list(), \"seed\": cur_df['seed'].to_list()}),\n",
    "        pd.DataFrame.from_dict({\"strong_model_size\": cur_df['strong_model_size'].to_list(), \"accuracy\": cur_df['strong_acc'].to_list(), \"seed\": cur_df['seed'].to_list()})\n",
    "    ])\n",
    "    base_accuracies = base_df.groupby('strong_model_size').agg({'accuracy': 'mean', 'seed': 'count'}).sort_values('accuracy')\n",
    "    base_accuracy_lookup = base_accuracies['accuracy'].to_dict()\n",
    "    base_accuracies = base_accuracies.reset_index()\n",
    "    base_df.reset_index(inplace=True)\n",
    "    base_df['weak_model_size'] = 'ground truth'\n",
    "    base_df['loss'] = 'xent'\n",
    "    base_df['strong_model_accuracy'] = base_df['strong_model_size'].apply(lambda x: base_accuracy_lookup[x])\n",
    "\n",
    "    weak_to_strong = cur_df[['weak_model_size', 'strong_model_size', 'seed'] + [e for e in cur_df.columns if e.startswith('transfer_acc')]]\n",
    "    weak_to_strong = weak_to_strong.melt(id_vars=['weak_model_size', 'strong_model_size', 'seed'], var_name='loss', value_name='accuracy')\n",
    "    weak_to_strong = weak_to_strong.dropna(subset=['accuracy'])\n",
    "    weak_to_strong.reset_index(inplace=True)\n",
    "    weak_to_strong['loss'] = weak_to_strong['loss'].str.replace('transfer_acc_', '')\n",
    "    weak_to_strong['strong_model_accuracy'] = weak_to_strong['strong_model_size'].apply(lambda x: base_accuracy_lookup[x])\n",
    "\n",
    "    # Exclude cases where the weak model is better than the strong model from PGR calculation.\n",
    "    pgr_df = cur_df[(cur_df['weak_model_size'] != cur_df['strong_model_size']) & (cur_df['strong_acc'] > cur_df['weak_acc'])]\n",
    "    pgr_df = pgr_df.melt(id_vars=[e for e in cur_df.columns if not e.startswith('transfer_acc')], var_name='loss', value_name='transfer_acc')\n",
    "    pgr_df = pgr_df.dropna(subset=['transfer_acc'])\n",
    "    pgr_df['loss'] = pgr_df['loss'].str.replace('transfer_acc_', '')\n",
    "    pgr_df['pgr'] = (pgr_df['transfer_acc'] - pgr_df['weak_acc']) / (pgr_df['strong_acc'] - pgr_df['weak_acc'])\n",
    "\n",
    "    for seed in [None] + (sorted(cur_df['seed'].unique().tolist()) if PLOT_ALL_SEEDS else []):\n",
    "        plot_df = pd.concat([base_df, weak_to_strong])\n",
    "        seed_pgr_df = pgr_df\n",
    "        if seed is not None:\n",
    "            plot_df = plot_df[plot_df['seed'] == seed]\n",
    "            # We mean across seeds, this is because sometimes the weak and strong models will have run on different hardware and therefore\n",
    "            # have slight differences. We want to average these out when filtering by seed.\n",
    "\n",
    "            seed_pgr_df = pgr_df[pgr_df['seed'] == seed]\n",
    "\n",
    "        if seed is not None or cur_df['seed'].nunique() == 1:\n",
    "            plot_df = plot_df[['strong_model_accuracy', 'weak_model_size', 'loss', 'accuracy']].groupby(['strong_model_accuracy', 'weak_model_size', 'loss']).mean().reset_index().sort_values(['loss', 'weak_model_size'], ascending=False)\n",
    "\n",
    "        print(f\"Dataset: {dataset} (seed: {seed})\")\n",
    "\n",
    "        pgr_results = seed_pgr_df.groupby(['loss']).aggregate({\"pgr\": \"median\"})\n",
    "        display(pgr_results)\n",
    "\n",
    "        palette = sns.color_palette('colorblind', n_colors=len(plot_df['weak_model_size'].unique()) - 1)\n",
    "        color_dict = {model: (\"black\" if model == 'ground truth' else palette.pop()) for model in plot_df['weak_model_size'].unique()}\n",
    "\n",
    "        sns.lineplot(data=plot_df, x='strong_model_accuracy', y='accuracy', hue='weak_model_size', style='loss', markers=True, palette=color_dict)\n",
    "        pd.plotting.table(plt.gca(), pgr_results.round(4), loc='lower right', colWidths=[0.1, 0.1], cellLoc='center', rowLoc='center')\n",
    "        plt.xticks(ticks=base_accuracies['accuracy'], labels=[f\"{e} ({base_accuracy_lookup[e]:.4f})\" for e in base_accuracies['strong_model_size']], rotation=90)\n",
    "        plt.title(f\"Dataset: {dataset} (seed: {seed})\")\n",
    "        plt.legend(loc='upper left')\n",
    "        plt.savefig(f\"{dataset.replace('/', '-')}_{seed}.png\", dpi=300, bbox_inches='tight')\n",
    "        plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "openai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
